import numpy as np
import faiss

class FaissKNeighbors:
    def __init__(self, k=1, res=None):
       
        self.index = None  # FAISS索引，用于存储训练数据
        self.y = None  # 训练数据的标签
        self.k = k  # 最近邻个数
        self.res = res  # FAISS GPU资源对象

    def fit(self, X, y):
       
        self.index = faiss.IndexFlatL2(X.shape[1])  # 初始化FAISS索引，使用L2距离
        
        # 如果提供了GPU资源，则将索引转移到GPU上
        if self.res is not None:
            self.index = faiss.index_cpu_to_gpu(self.res, 0, self.index)

        self.index.add(X.astype(np.float32))  # 将训练数据添加到索引
        self.y = y  

    def predict(self, X):
       
        # 搜索X中每个向量的k个最近邻
        distances, indices = self.index.search(X.astype(np.float32), self.k)
        votes = self.y[indices]  # 获取最近邻的标签
        
        # 通过投票机制确定最终的预测标签
        predictions = np.array([np.argmax(np.bincount(vote)) for vote in votes])
        return predictions

    def score(self, X, y):
        
        predictions = self.predict(X)  # 获取预测结果
        accuracy = np.mean(predictions == y)  # 计算准确率
        return accuracy
